#! /usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import division, print_function, absolute_import
from timeit import time
import warnings
import cv2
import numpy as np
from yoloTiny import YOLOTiny

import tensorflow as tf
from tensorflow.python.saved_model import tag_constants
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

from deep_sort import preprocessing
from deep_sort import nn_matching
from deep_sort.detection import Detection
from deep_sort.detection_yolo import Detection_YOLO
from deep_sort.tracker import Tracker
from tools import generate_detections_tiny as gdet
import os

# Silence every warning for the whole process: no category, message or module
# filter is given, so the 'ignore' action applies to all of them.
warnings.filterwarnings('ignore')


def _rate(events, elapsed):
    """Return ``events / elapsed``, or 0.0 when ``elapsed`` is not positive."""
    if elapsed > 0:
        return float(events) / elapsed
    return 0.0


def _draw_labelled_box(frame, tlbr, box_color, label, label_at_bottom):
    """Draw a 2px rectangle for ``tlbr`` on ``frame`` and write ``label`` next to it.

    ``tlbr`` is indexed as (x1, y1, x2, y2).  The label is anchored at
    (x1, y2) when ``label_at_bottom`` is true, otherwise at (x1, y1).
    """
    x1, y1, x2, y2 = int(tlbr[0]), int(tlbr[1]), int(tlbr[2]), int(tlbr[3])
    cv2.rectangle(frame, (x1, y1), (x2, y2), box_color, 2)
    origin = (x1, y2) if label_at_bottom else (x1, y1)
    # Font scale grows with the frame height so labels stay readable.
    cv2.putText(frame, label, origin, 0, 1.5e-3 * frame.shape[0], (0, 255, 0), 1)


def main(file_path='E:/Video/video(9).MP4'):
    """Run detection + Deep SORT tracking on a video and display the result.

    Every frame is passed to the ``YOLOTiny`` detector, the detections are
    filtered with non-maxima suppression, fed to the Deep SORT ``Tracker``
    and drawn on the frame, which is shown in an OpenCV window.  The smoothed
    per-frame FPS is printed after each frame and the overall FPS at the end.
    Press ``q`` in the window to stop early.

    Args:
        file_path: Path of the video to process.  Defaults to the path that
            used to be hard-coded, so ``main()`` behaves as before.
    """
    # Definition of the parameters
    max_cosine_distance = 0.3
    nn_budget = None
    nms_max_overlap = 1.0

    # Deep SORT
    work_dir = os.path.dirname(os.path.abspath(__file__))
    model_filename = os.path.join(work_dir, 'model_data/yolo/mars-small128.pb')
    encoder = gdet.create_box_encoder(model_filename, batch_size=1)

    metric = nn_matching.NearestNeighborDistanceMetric("cosine", max_cosine_distance, nn_budget)
    tracker = Tracker(metric)

    # NOTE (translated from the original author): the session / model set-up
    # below has to happen at exactly this point (after the encoder has been
    # created); otherwise model prediction fails.  Do not reorder it.
    config = ConfigProto(gpu_options=tf.compat.v1.GPUOptions(allow_growth=True, per_process_gpu_memory_fraction=0.1))
    session = InteractiveSession(config=config)

    video_capture = None
    try:
        # Keep ``saved_model_loaded`` referenced for as long as ``infer`` is
        # used (presumably the signature depends on the loaded object).
        saved_model_loaded = tf.saved_model.load(os.path.join(work_dir, 'model_data/yolov4-tiny-416'),
                                                 tags=[tag_constants.SERVING])
        infer = saved_model_loaded.signatures['serving_default']
        # End of the order-sensitive set-up.
        yoloNew = YOLOTiny(['Car', 'Bus', 'Truck'], infer)

        tracking = True

        video_capture = cv2.VideoCapture(file_path)
        if not video_capture.isOpened():
            # Fail loudly instead of silently reporting a negative overall FPS.
            print("Cannot open video: %s" % file_path)
            return

        fps = 0.0

        startTime = time.time()
        count = 0
        print("start")
        while True:
            ret, frame = video_capture.read()
            if not ret:
                break
            # Restart the overall timer once the first frame has been fully
            # processed, so one-off warm-up cost is excluded from ALLFPS.
            if count == 1:
                startTime = time.time()
            count += 1
            t1 = time.time()

            # The detector gets an RGB copy; the encoder and the drawing code
            # keep working on the original BGR frame.
            detectionFrame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            bboxes, scores, names = yoloNew.detect_image(detectionFrame)
            if tracking:
                features = encoder(frame, bboxes)
                detections = [Detection(bbox, score, class_name, feature)
                              for bbox, score, class_name, feature in zip(bboxes, scores, names, features)]
            else:
                detections = [Detection_YOLO(bbox, confidence, cls)
                              for bbox, confidence, cls in zip(bboxes, scores, names)]

            # Run non-maxima suppression.
            boxs = np.array([d.tlwh for d in detections])
            confidences = np.array([d.confidence for d in detections])
            indices = preprocessing.non_max_suppression(boxs, nms_max_overlap, confidences)
            detections = [detections[i] for i in indices]

            if tracking:
                # Call the tracker
                tracker.predict()
                tracker.update(detections)

                for track in tracker.tracks:
                    # Only draw confirmed tracks that were updated recently.
                    if not track.is_confirmed() or track.time_since_update > 1:
                        continue
                    _draw_labelled_box(frame, track.to_tlbr(), (255, 255, 255),
                                       "ID: " + str(track.track_id), False)

            for det in detections:
                score = "%.2f" % round(det.confidence * 100, 2) + "%"
                _draw_labelled_box(frame, det.to_tlbr(), (255, 0, 0),
                                   str(det.class_name) + " " + score, True)

            cv2.imshow('', frame)

            # Running average of the instantaneous FPS; skip the update when
            # the clock did not advance to avoid a division by zero.
            frame_time = time.time() - t1
            if frame_time > 0:
                fps = (fps + (1. / frame_time)) / 2
            print("FPS = %f" % (fps))

            # Press Q to stop!
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

        # The first frame is excluded from the overall figure (see timer
        # restart above); never report a negative frame count.
        fps = _rate(max(count - 1, 0), time.time() - startTime)
        print("ALLFPS = %f" % (fps))
    finally:
        # Release native resources even if processing raised.
        if video_capture is not None:
            video_capture.release()
        cv2.destroyAllWindows()
        session.close()


# Script entry point: run the demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
# Legacy invocation kept for reference; main() does not accept a detector
# argument, so this call no longer matches its signature:
# main(YOLO(['Car']))
